{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reranker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A reranker uses a cross-encoder architecture: it takes the query and the text together as input and directly outputs a relevance score for the pair. It scores query-text relevance more accurately than an embedding-based retriever, but at the cost of slower inference. A complete retrieval system therefore usually uses a retriever in the first stage to fetch candidates from the whole corpus, followed by a reranker that reorders those candidates more precisely.\n",
    "\n",
    "In this tutorial, we will go through a text retrieval pipeline with a reranker and evaluate the results before and after reranking.\n",
    "\n",
    "Note: Steps 1-4 are identical to those in the [evaluation](https://github.com/FlagOpen/FlagEmbedding/tree/master/Tutorials/4_Evaluation) tutorial. We suggest going through that tutorial first if you are not familiar with retrieval."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the dependencies in the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U FlagEmbedding faiss-cpu"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the MS MARCO dataset, then take the first 100 queries and build the corpus from the positive passages of the first 5,000 samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import numpy as np\n",
    "\n",
    "data = load_dataset(\"namespace-Pt/msmarco\", split=\"dev\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "queries = np.array(data[:100][\"query\"])\n",
    "corpus = sum(data[:5000][\"positive\"], [])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Inference Embeddings: 100%|██████████| 21/21 [01:59<00:00,  5.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shape of the corpus embeddings: (5331, 768)\n",
      "data type of the embeddings:  float32\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import FlagModel\n",
    "\n",
    "# get the BGE embedding model\n",
    "model = FlagModel('BAAI/bge-base-en-v1.5',\n",
    "                  query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "                  use_fp16=True)\n",
    "\n",
    "# get the embedding of the corpus\n",
    "corpus_embeddings = model.encode(corpus)\n",
    "\n",
    "print(\"shape of the corpus embeddings:\", corpus_embeddings.shape)\n",
    "print(\"data type of the embeddings: \", corpus_embeddings.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Indexing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total number of vectors: 5331\n"
     ]
    }
   ],
   "source": [
    "import faiss\n",
    "\n",
    "# get the dimension of the embedding vectors; bge-base-en-v1.5 produces 768-dimensional vectors\n",
    "dim = corpus_embeddings.shape[-1]\n",
    "\n",
    "# create the faiss index and store the corpus embeddings into the vector space\n",
    "index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)\n",
    "corpus_embeddings = corpus_embeddings.astype(np.float32)\n",
    "index.train(corpus_embeddings)\n",
    "index.add(corpus_embeddings)\n",
    "\n",
    "print(f\"total number of vectors: {index.ntotal}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "query_embeddings = model.encode_queries(queries)\n",
    "ground_truths = [d[\"positive\"] for d in data]\n",
    "corpus = np.asarray(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Searching: 100%|██████████| 1/1 [00:00<00:00, 22.35it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "res_scores, res_ids, res_text = [], [], []\n",
    "query_size = len(query_embeddings)\n",
    "batch_size = 256\n",
    "# cutoffs used during evaluation; k is set to the largest cutoff\n",
    "cut_offs = [1, 10]\n",
    "k = max(cut_offs)\n",
    "\n",
    "for i in tqdm(range(0, query_size, batch_size), desc=\"Searching\"):\n",
    "    q_embedding = query_embeddings[i: min(i+batch_size, query_size)].astype(np.float32)\n",
    "    # search the top k answers for each of the queries\n",
    "    score, idx = index.search(q_embedding, k=k)\n",
    "    res_scores += list(score)\n",
    "    res_ids += list(idx)\n",
    "    res_text += list(corpus[idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Reranking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will use a reranker to rerank the list of answers we retrieved using our index. Hopefully, this will lead to better results."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following table lists the available BGE rerankers. Feel free to try out to see their differences!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Model  | Language |   Parameters   |    Description    |   Base Model     |\n",
    "|:-------|:--------:|:----:|:-----------------:|:--------------------------------------:|\n",
    "| [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) | Multilingual |     568M     | a lightweight cross-encoder model, possesses strong multilingual capabilities, easy to deploy, with fast inference. | XLM-RoBERTa-Large |\n",
    "| [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma) | Multilingual |     2.51B     | a cross-encoder model which is suitable for multilingual contexts, performs well in both English and multilingual settings. | Gemma-2B |\n",
    "| [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) | Multilingual |    2.72B    | a cross-encoder model which is suitable for multilingual contexts, performs well in both English and Chinese proficiency, allows freedom to select layers for output, facilitating accelerated inference. | MiniCPM |\n",
    "| [BAAI/bge-reranker-v2.5-gemma2-lightweight](https://huggingface.co/BAAI/bge-reranker-v2.5-gemma2-lightweight) | Multilingual |    9.24B    | a cross-encoder model which is suitable for multilingual contexts, performs well in both English and Chinese proficiency, allows freedom to select layers, compress ratio and compress layers for output, facilitating accelerated inference. | Gemma2-9B |\n",
    "| [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) |   Chinese and English |     560M     |   a cross-encoder model which is more accurate but less efficient    |  XLM-RoBERTa-Large  |\n",
    "| [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)   |   Chinese and English |     278M     |  a cross-encoder model which is more accurate but less efficient     |  XLM-RoBERTa-Base  |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's use a small example to see how a reranker works:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-9.474676132202148, -2.823843240737915, 5.76226806640625]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import FlagReranker\n",
    "\n",
    "# setting use_fp16 to True speeds up computation at the cost of a slight performance degradation\n",
    "reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)\n",
    "\n",
    "# use the compute_score() function to calculate scores for each input sentence pair\n",
    "scores = reranker.compute_score([\n",
    "    ['what is panda?', 'Today is a sunny day'], \n",
    "    ['what is panda?', 'The tiger (Panthera tigris) is a member of the genus Panthera and the largest living cat species native to Asia.'],\n",
    "    ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']\n",
    "    ])\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's use the reranker to rerank our previously retrieved results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_ids, new_scores, new_text = [], [], []\n",
    "for i in range(len(queries)):\n",
    "    # get the new scores of the previously retrieved results\n",
    "    new_score = reranker.compute_score([[queries[i], text] for text in res_text[i]])\n",
    "    # sort the lists of ids and scores by the new scores\n",
    "    new_id = [tup[1] for tup in sorted(list(zip(new_score, res_ids[i])), reverse=True)]\n",
    "    new_scores.append(sorted(new_score, reverse=True))\n",
    "    new_ids.append(new_id)\n",
    "    new_text.append(corpus[new_id])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For details of these metrics, please check out the [evaluation](https://github.com/FlagOpen/FlagEmbedding/tree/master/Tutorials/4_Evaluation) tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1 Recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_recall(preds, truths, cutoffs):\n",
    "    recalls = np.zeros(len(cutoffs))\n",
    "    for text, truth in zip(preds, truths):\n",
    "        for i, c in enumerate(cutoffs):\n",
    "            # ground-truth passages that appear in the top-c results\n",
    "            hits = np.intersect1d(truth, text[:c])\n",
    "            # normalize by the number of positives that could fit in the top c\n",
    "            recalls[i] += len(hits) / max(min(c, len(truth)), 1)\n",
    "    # average over all queries\n",
    "    recalls /= len(preds)\n",
    "    return recalls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before reranking:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "recall@1:\t0.97\n",
      "recall@10:\t1.0\n"
     ]
    }
   ],
   "source": [
    "recalls_init = calc_recall(res_text, ground_truths, cut_offs)\n",
    "for i, c in enumerate(cut_offs):\n",
    "    print(f\"recall@{c}:\\t{recalls_init[i]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After reranking:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "recall@1:\t0.99\n",
      "recall@10:\t1.0\n"
     ]
    }
   ],
   "source": [
    "recalls_rerank = calc_recall(new_text, ground_truths, cut_offs)\n",
    "for i, c in enumerate(cut_offs):\n",
    "    print(f\"recall@{c}:\\t{recalls_rerank[i]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 MRR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MRR(preds, truth, cutoffs):\n",
    "    mrr = [0 for _ in range(len(cutoffs))]\n",
    "    for pred, t in zip(preds, truth):\n",
    "        for i, c in enumerate(cutoffs):\n",
    "            for j, p in enumerate(pred):\n",
    "                if j < c and p in t:\n",
    "                    mrr[i] += 1/(j+1)\n",
    "                    break\n",
    "    mrr = [k/len(preds) for k in mrr]\n",
    "    return mrr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before reranking:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MRR@1:\t0.97\n",
      "MRR@10:\t0.9825\n"
     ]
    }
   ],
   "source": [
    "mrr_init = MRR(res_text, ground_truths, cut_offs)\n",
    "for i, c in enumerate(cut_offs):\n",
    "    print(f\"MRR@{c}:\\t{mrr_init[i]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After reranking:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MRR@1:\t0.99\n",
      "MRR@10:\t0.995\n"
     ]
    }
   ],
   "source": [
    "mrr_rerank = MRR(new_text, ground_truths, cut_offs)\n",
    "for i, c in enumerate(cut_offs):\n",
    "    print(f\"MRR@{c}:\\t{mrr_rerank[i]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.3 nDCG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before reranking:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nDCG@1: 0.97\n",
      "nDCG@10: 0.9869253606521631\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import ndcg_score\n",
    "\n",
    "pred_hard_encodings = []\n",
    "for pred, label in zip(res_text, ground_truths):\n",
    "    pred_hard_encoding = list(np.isin(pred, label).astype(int))\n",
    "    pred_hard_encodings.append(pred_hard_encoding)\n",
    "\n",
    "for i, c in enumerate(cut_offs):\n",
    "    nDCG = ndcg_score(pred_hard_encodings, res_scores, k=c)\n",
    "    print(f\"nDCG@{c}: {nDCG}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After reranking:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nDCG@1: 0.99\n",
      "nDCG@10: 0.9963092975357145\n"
     ]
    }
   ],
   "source": [
    "pred_hard_encodings_rerank = []\n",
    "for pred, label in zip(new_text, ground_truths):\n",
    "    pred_hard_encoding = list(np.isin(pred, label).astype(int))\n",
    "    pred_hard_encodings_rerank.append(pred_hard_encoding)\n",
    "\n",
    "for i, c in enumerate(cut_offs):\n",
    "    nDCG = ndcg_score(pred_hard_encodings_rerank, new_scores, k=c)\n",
    "    print(f\"nDCG@{c}: {nDCG}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
